#!/usr/bin/env python
# print("importing libraries")
from tartan_logging import TartanLogging
import numpy as np
import argparse
import dill
import matplotlib
# matplotlib.use('pdf')
import matplotlib.pyplot as plt
import cv2
import pylab as pl
import math
from sm import PlotCollection
from matplotlib.backends.backend_pdf import PdfPages
import cv2 
import glob, os
import pandas as pd

# Root directory that main() searches for experiment log files (*.pkl).
experiments_folder = '/data/final_experiments/'
# Name of the per-experiment summary file; Experiment looks for it in the
# same directory as the log file.
summary_file = 'summary.txt'
# Number of polar-angle bins used by evaluate() for the normalized error.
nbins = 100
# Path of the CSV file written by main() (a file path, despite the name).
csv_dir = '/data/final_experiments/experiments.csv'
# Only referenced from commented-out code in evaluate(); there, -1 appears to
# mean "no cutoff" -- NOTE(review): confirm before re-enabling that code.
polar_cutoff = -1

class Experiment:
    """Metadata of one experiment, read from the summary file next to its log.

    The summary file (module-level ``summary_file``) is expected to sit in the
    same directory as ``log_file`` and to hold at least four lines, in order:
    bag, target, number of cameras, model.

    Attributes:
        log_file: path of the log file passed to the constructor.
        dirname: directory of ``log_file`` with a trailing '/'.
        summary_file: full path of the summary file that was read.
        bag, target, model: lines 1, 2 and 4 of the summary (line endings
            stripped).
        num_cams: line 3 of the summary parsed as an int.

    Raises:
        IOError/OSError: if the summary file cannot be opened.
        IndexError: if the summary file has fewer than four lines.
        ValueError: if the third line is not an integer.
    """

    def __init__(self, log_file):
        self.log_file = log_file
        self.dirname = os.path.dirname(self.log_file) + '/'
        self.summary_file = self.dirname + summary_file

        # Use a context manager so the handle is always closed, and drop the
        # line terminators so they do not end up inside the CSV cells.
        with open(self.summary_file, 'r') as handle:
            lines = [line.rstrip('\r\n') for line in handle]

        self.bag = lines[0]
        self.target = lines[1]
        self.num_cams = int(lines[2])
        self.model = lines[3]

def get_outliers(array, num_stds):
    """Return the column indices of points that are outliers in either row.

    ``array`` is indexed as ``array[component, point]`` with (at least) two
    component rows. A point is flagged when the absolute value of its first
    or second component exceeds ``num_stds`` times the standard deviation of
    that component over all points. The values are compared as-is (they are
    not mean-centred).

    Args:
        array: 2-D array-like of shape (>=2, num_points).
        num_stds: threshold expressed in standard deviations.

    Returns:
        A list of int column indices, in ascending order.
    """
    values = np.asarray(array)
    # Per-component spread: reduce over the point axis (axis 1). The builtin
    # ``float`` replaces the ``np.float`` alias, which newer NumPy removed.
    std = np.std(values, axis=1, dtype=float)
    flagged = (np.abs(values[0]) > num_stds * std[0]) | \
              (np.abs(values[1]) > num_stds * std[1])
    return np.flatnonzero(flagged).tolist()

def evaluate(cams, obsdb, num_bins=None):
    """Compute plain and polar-angle-normalized mean reprojection error.

    For every camera ``cams[i]`` the reprojection errors of ``obsdb[i]`` are
    reduced to one magnitude per point, and each observed corner is
    back-projected with ``cam.keypointToEuclidean`` to get its polar angle
    (degrees between the returned ray and its last axis). Points from all
    cameras are pooled, split into ``num_bins`` equal-width polar-angle bins,
    and the normalized error is the mean of the per-bin mean errors over the
    non-empty bins.

    Args:
        cams: sequence of camera objects providing ``getReprojectionErrors``
            and ``keypointToEuclidean``.
        obsdb: sequence parallel to ``cams``; ``obsdb[i]`` is an iterable of
            observations providing ``getCornersImageFrame()``.
            Assumes the errors reshape to (num_points, 2) and corners are
            (n, 2) pixel arrays in matching order -- TODO confirm.
        num_bins: number of polar bins (positive int); defaults to the
            module-level ``nbins``.

    Returns:
        ``[mean_error, mean_of_bin_means, total_num_points]``. On any failure
        (or when there are no points) the error is reported and ``[0, 0, 0]``
        is returned so that a batch run can continue.
    """
    if num_bins is None:
        num_bins = nbins

    try:
        err_chunks = []
        polars = []
        total_num_points = 0
        for i, cam in enumerate(cams):
            observations = obsdb[i]

            # One error magnitude per point.
            raw_err = cam.getReprojectionErrors(observations)
            num_points = np.shape(raw_err)[0]
            err_chunks.append(
                np.linalg.norm(np.reshape(raw_err, (num_points, 2)), axis=1))

            # Polar angle of every observed corner.
            corners = np.concatenate(
                [obs.getCornersImageFrame()[:, :2] for obs in observations])
            for j in range(num_points):
                keypoint = np.array([[corners[j, 0]], [corners[j, 1]]])
                # ravel() so both (3,) and (3, 1) results yield a scalar angle
                ray = np.ravel(cam.keypointToEuclidean(keypoint))
                polars.append(float(np.rad2deg(
                    np.arctan2(np.linalg.norm(ray[:2]), ray[-1]))))

            total_num_points += num_points

        if total_num_points == 0:
            print("Something went wrong")
            print("evaluate: no points to evaluate")
            return [0, 0, 0]

        err = np.concatenate(err_chunks)
        polars = np.asarray(polars, dtype=float)

        edges = np.linspace(polars.min(), polars.max(), num_bins + 1)
        # np.digitize puts values equal to the last edge one past the final
        # bin; clip so the maximum-angle points land in the last bin instead
        # of being silently dropped.
        bin_idxs = np.clip(np.digitize(polars, edges) - 1, 0, num_bins - 1)

        counts = np.bincount(bin_idxs, minlength=num_bins)
        sums = np.bincount(bin_idxs, weights=err, minlength=num_bins)
        occupied = counts > 0
        bin_avg_errs = sums[occupied] / counts[occupied]

        return [np.mean(err), np.mean(bin_avg_errs), total_num_points]

    except Exception:
        # Best effort: keep the batch going, but show what actually failed.
        import traceback
        print("Something went wrong")
        traceback.print_exc()
        return [0, 0, 0]

        # num_points = np.shape(err)[0]
        # num_points = len(err)
        # err = np.transpose(np.reshape(err,(num_points,2)))
        # mag_err = np.array([np.linalg.norm(err[:,q]) for q in range(num_points)])
        # # remove outliers
        

        # # get polar angles
        # img_points_x = np.concatenate([obs.getCornersImageFrame()[:,0] for obs in obslist ])
        # img_points_y = np.concatenate([obs.getCornersImageFrame()[:,1] for obs in obslist ])

        # counter = 0

        # polars = []
        # for j in range(num_points):
        #     ec_pnt = cam.keypointToEuclidean(np.array([[img_points_x[j]],[img_points_y[j]]])) 
        #     polars.append(np.rad2deg(np.arctan2(np.linalg.norm(ec_pnt[:2]),ec_pnt[-1])))

        # polars = np.array(polars)
        

        # polar_idxs = np.argsort(polars)
        # # print("Stats for {0}, mean: {1}, std:{2}.".format(camNames[i],str(np.mean(np.array(err), 1, dtype=np.float)),str(np.std(np.array(err), 1, dtype=np.float))))
        # err = np.array([np.linalg.norm(err[:,q]) for q in range(num_points)])

        # polar_min = np.min(polars)
        # polar_max = np.max(polars)
        # bins = np.linspace(polar_min,polar_max,parsed.nbins+1)
        # bin_idxs = np.digitize(polars,bins)-1
        # bin_avg_errs = np.zeros(parsed.nbins)
        # empty_idxs = []
        # full_idxs = []
        # for k in range(parsed.nbins):
        #     if len(np.where(bin_idxs==k)[0]!=0):
        #         bin_avg_errs[k] = np.mean(err[np.where(bin_idxs==k)])
        #         full_idxs.append(k)
        #     else:
        #         empty_idxs.append(k)
        
        # bin_avg_errs = np.delete(bin_avg_errs,empty_idxs)



def get_metrics(experiment):
    """Load an experiment's log and evaluate its calibrations.

    The log is a dill-serialized sequence of entries, each exposing
    ``ObservationDatabase_`` and ``camera_``. The first entry is treated as
    the Kalibr result; the Tartan result uses the observations of the last
    entry and the cameras of the entry at index 1.

    Args:
        experiment: object with ``log_file`` (path) and ``num_cams`` (int).

    Returns:
        A list of eight values:
        ``[kalibr_err, kalibr_norm_err, tartan_err, tartan_norm_err,
        kalibr_on_tartan_err, kalibr_on_tartan_norm_err,
        kalibr_num_points, tartan_num_points]``.
        If the Tartan part fails, its five entries are 0.
    """
    # SECURITY NOTE: dill.load can execute arbitrary code contained in the
    # file; only load logs produced by trusted runs.
    # Pickle streams are binary data, so open in 'rb' (text mode breaks on
    # Python 3 and is not needed on Python 2).
    with open(experiment.log_file, 'rb') as f:
        tartanObsLog = dill.load(f)

    tartanObsdb = [log.ObservationDatabase_ for log in tartanObsLog]
    cam_ids = range(experiment.num_cams)

    kalibrobsdb = tartanObsdb[0]
    tartanobsdb = tartanObsdb[-1]

    kalibrobs = [kalibrobsdb.observations[i] for i in cam_ids]
    kalibrcams = [tartanObsLog[0].camera_[i] for i in cam_ids]

    kalibr_err = evaluate(kalibrcams, kalibrobs)

    try:
        tartanobs = [tartanobsdb.observations[i] for i in cam_ids]
        # NOTE(review): cameras come from log index 1 while the observations
        # come from index -1; these differ when the log has more than two
        # entries -- confirm this is intended.
        tartancams = [tartanObsLog[1].camera_[i] for i in cam_ids]

        tartan_err = evaluate(tartancams, tartanobs)
        kalibr_in_tartan_err = evaluate(kalibrcams, tartanobs)

        return (kalibr_err[:2] + tartan_err[:2] + kalibr_in_tartan_err[:2]
                + [kalibr_err[-1]] + [tartan_err[-1]])

    except Exception:
        # Best effort: still report the Kalibr numbers, but do not hide why
        # the Tartan evaluation was skipped.
        import traceback
        print("Tartan evaluation failed for " + str(experiment.log_file))
        traceback.print_exc()
        return kalibr_err[:2] + [0, 0] + [0, 0] + [kalibr_err[-1]] + [0]

    # print("obs iteration:"+str(parsed.obs_it))
    # evalObsDb = tartanObs[int(parsed.obs_it)]
    # evalObs = [evalObsDb.observations[int(parsed.camids[i])] for i in range(num_evals)]

    # evalCams = []
    
    # if type(parsed.camids) != list:
    #     cam_list = []
    #     for _ in range(num_evals):
    #         cam_list.append(int(parsed.camids))
    # else:
    #     cam_list = list(map(int,parsed.camids))
    
    # for i, camlog in enumerate(parsed.cams):
    #     with open(camlog, 'r') as f:
    #         camlog_loaded = dill.load(f)

    #     evalCams.append(camlog_loaded[int(parsed.cams_it[i])].camera_[cam_list[i]])


#     num_evals = len(evalCams)
#     if(parsed.babel_args != 0):
#         num_evals += 1

#     for i in range(num_evals):
#         if (i == (num_evals-1) and parsed.babel_args != 0):
#             cam = evalCams[i-1]
#             obslist = evalObs[i-1]
#             err = np.array(cam.getReprojectionErrors(evalObs[i-1]))
#             babel_float = np.array([float(babel_arg) for babel_arg in parsed.babel_args])
#             cam.projection().setParameters(babel_float)
#         else:
#             cam = evalCams[i]
#             print(str(cam.projection().getParameters()))
#             obslist = evalObs[i]
#             err = np.array(cam.getReprojectionErrors(evalObs[i]))
#         # delete_idxs = get_outliers(err,4)
#         # err = np.delete(err,delete_idxs,1)
#         num_points = np.shape(err)[0]
#         num_points = len(err)
#         err = np.transpose(np.reshape(err,(num_points,2)))
#         mag_err = np.array([np.linalg.norm(err[:,q]) for q in range(num_points)])
#         # remove outliers
        

#         # get polar angles
#         img_points_x = np.concatenate([obs.getCornersImageFrame()[:,0] for obs in obslist ])
#         img_points_y = np.concatenate([obs.getCornersImageFrame()[:,1] for obs in obslist ])

#         counter = 0

#         polars = []
#         for j in range(num_points):
#             ec_pnt = cam.keypointToEuclidean(np.array([[img_points_x[j]],[img_points_y[j]]])) 
#             polars.append(np.rad2deg(np.arctan2(np.linalg.norm(ec_pnt[:2]),ec_pnt[-1])))

#         polars = np.array(polars)
        

#         polar_idxs = np.argsort(polars)
#         # print("Stats for {0}, mean: {1}, std:{2}.".format(camNames[i],str(np.mean(np.array(err), 1, dtype=np.float)),str(np.std(np.array(err), 1, dtype=np.float))))
#         err = np.array([np.linalg.norm(err[:,q]) for q in range(num_points)])

#         polar_min = np.min(polars)
#         polar_max = np.max(polars)
#         bins = np.linspace(polar_min,polar_max,parsed.nbins+1)
#         bin_idxs = np.digitize(polars,bins)-1
#         bin_avg_errs = np.zeros(parsed.nbins)
#         empty_idxs = []
#         full_idxs = []
#         for k in range(parsed.nbins):
#             if len(np.where(bin_idxs==k)[0]!=0):
#                 bin_avg_errs[k] = np.mean(err[np.where(bin_idxs==k)])
#                 full_idxs.append(k)
#             else:
#                 empty_idxs.append(k)
        
#         bin_avg_errs = np.delete(bin_avg_errs,empty_idxs)
#         # print("normal reproj. err:"+str(np.mean(err)))
#         # print("spatially normalized: "+str(np.mean(bin_avg_errs)))


def main():
    """Evaluate every experiment log found and collect the metrics in a CSV.

    Log files are matched with the glob pattern
    ``experiments_folder + '/**/*.pkl'``. For each one the directory name is
    printed, ``get_metrics`` is run, a row is appended to the table, and the
    whole table is written to ``csv_dir`` again, so partial results are on
    disk after every experiment.
    """
    # NOTE(review): glob is called without recursive matching, so '**'
    # behaves like '*' here -- confirm one directory level is intended.
    pattern = experiments_folder+'/**/*.pkl'
    experiments = [Experiment(path) for path in glob.glob(pattern)]

    column_names = ['log_dir','bag','target','num_cams','model','kalibr error','kalibr normalized err','tartan err','tartan normalized err','kalibr evaluated on tartan','kalibr on tartan normalized','kalibr num corners','tartan num corners']
    df = pd.DataFrame(columns=column_names)

    for experiment in experiments:
        print(experiment.dirname)
        metrics = get_metrics(experiment)

        row = [
            experiment.dirname,
            experiment.bag,
            experiment.target,
            experiment.num_cams,
            experiment.model,
        ]
        # exactly the eight metric values, in the order get_metrics gives them
        row += [metrics[k] for k in range(8)]

        df.loc[len(df.index)] = row
        df.to_csv(csv_dir)


    # assert len(parsed.cams) == len(parsed.cams_it)
    # num_evals = len(parsed.cams)

    # with open(parsed.obsfile, 'r') as f:
    #     tartanObsLog = dill.load(f)
    
    # tartanObs = [log.ObservationDatabase_ for log in tartanObsLog]
    # print("obs iteration:"+str(parsed.obs_it))
    # evalObsDb = tartanObs[int(parsed.obs_it)]
    # evalObs = [evalObsDb.observations[int(parsed.camids[i])] for i in range(num_evals)]

    # evalCams = []
    
    # if type(parsed.camids) != list:
    #     cam_list = []
    #     for _ in range(num_evals):
    #         cam_list.append(int(parsed.camids))
    # else:
    #     cam_list = list(map(int,parsed.camids))
    
    # for i, camlog in enumerate(parsed.cams):
    #     with open(camlog, 'r') as f:
    #         camlog_loaded = dill.load(f)

    #     evalCams.append(camlog_loaded[int(parsed.cams_it[i])].camera_[cam_list[i]])


    # plotSphericalError(evalObs,evalCams,parsed.cam_names,parsed)


if __name__ == "__main__":
    main()
